//
// Created by Xing on 2022/3/10.
//

#include <algorithm>
#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layers/instance_norm_layer.hpp"
#include "caffe/util/math_functions.hpp"
#include <cmath>

namespace caffe {
    template<typename Dtype>
    template<typename Dtype>
    void InstanceNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*> &bottom,
                                              const vector<Blob<Dtype>*> &top) {
        // Read layer options: eps_ stabilizes the variance inversion, and
        // scale_bias_ selects whether a learnable per-channel affine
        // transform (scale/bias) follows the normalization.
        InstanceNormParameter param = this->layer_param_.instance_normalize_param();
        channels_ = bottom[0]->shape(1);
        eps_ = param.eps();
        scale_bias_ = param.scale_bias(); // by default = false;
        // Supplying either filler implicitly enables the affine transform.
        if (param.has_scale_filler() || param.has_bias_filler())
            scale_bias_ = true;

        if (this->blobs_.size() > 0) {
            // Parameters were already loaded (e.g. from a snapshot).
            LOG(INFO) << "Skipping parameter initialization";
        } else if (scale_bias_) {
            // BUG FIX: allocate the two parameter blobs only when the affine
            // transform is enabled. The previous code resized blobs_ to 2
            // unconditionally, so with scale_bias_ == false the vector held
            // two null shared_ptrs and any later iteration over blobs_
            // (parameter export, solver updates) dereferenced null.
            this->blobs_.resize(2);
            const vector<int> shape{channels_};
            this->blobs_[0].reset(new Blob<Dtype>(shape)); // scale (gamma)
            this->blobs_[1].reset(new Blob<Dtype>(shape)); // bias  (beta)

            // Scale defaults to a constant 1 filler (identity) when no
            // explicit filler is configured.
            FillerParameter scale_param(param.scale_filler());
            if (!param.has_scale_filler()) {
                scale_param.set_type("constant");
                scale_param.set_value(1.0f);
            }
            shared_ptr<Filler<Dtype> > scale_filler(GetFiller<Dtype>(scale_param));
            scale_filler->Fill(this->blobs_[0].get());

            // Bias defaults to a constant 0 filler (identity) when no
            // explicit filler is configured.
            FillerParameter bias_param(param.bias_filler());
            if (!param.has_bias_filler()) {
                bias_param.set_type("constant");
                bias_param.set_value(0.0f);
            }
            shared_ptr<Filler<Dtype>> bias_filler(GetFiller<Dtype>(bias_param));
            bias_filler->Fill(this->blobs_[1].get());
        }
    }

    template<typename Dtype>
    void InstanceNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*> &bottom,
                                           const vector<Blob<Dtype>*> &top) {
        // The channel axis must agree with what LayerSetUp recorded.
        if (bottom[0]->num_axes() > 1)
            CHECK_EQ(bottom[0]->shape(1), channels_);
        top[0]->ReshapeLike(*bottom[0]);

        // Per-(sample, channel) statistics buffers: one entry per N*C pair.
        const int nc = bottom[0]->count(0, 2);
        mean_.Reshape({nc});
        var_.Reshape({nc});
        inv_var_.Reshape({nc});

        // All-ones helper vectors (lengths N, C, and H*W) and a scratch
        // N*C buffer, each initialized to 1.
        ones_N_.Reshape({bottom[0]->shape(0)});
        caffe_set(ones_N_.count(), Dtype(1.0f), ones_N_.mutable_cpu_data());
        ones_C_.Reshape({bottom[0]->shape(1)});
        caffe_set(ones_C_.count(), Dtype(1.0f), ones_C_.mutable_cpu_data());
        ones_HW_.Reshape({bottom[0]->count(2)});
        caffe_set(ones_HW_.count(), Dtype(1.0f), ones_HW_.mutable_cpu_data());
        temp_NC_.Reshape({nc});
        caffe_set(temp_NC_.count(), Dtype(1.0f), temp_NC_.mutable_cpu_data());

        // Full-size scratch blobs shaped like the input.
        temp_NCHW_.ReshapeLike(*bottom[0]);
        x_norm_.ReshapeLike(*bottom[0]);
    }

    template<typename Dtype>
    void InstanceNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*> &bottom,
                                               const vector<Blob<Dtype>*> &top) {
        // Normalizes each (sample, channel) slice to zero mean / unit
        // variance, then optionally applies a per-channel scale and bias.
        const int num = bottom[0]->shape(0);       // N
        const int ch = bottom[0]->shape(1);        // C
        const int spatial = bottom[0]->count(2);   // H * W
        const int nc = bottom[0]->count(0, 2);     // N * C
        const int count = top[0]->count();

        const Dtype *in = bottom[0]->cpu_data();
        Dtype *out = top[0]->mutable_cpu_data();

        // Per-(n, c) mean over the spatial dimensions.
        compute_mean_per_channel_cpu(num, ch, spatial, in,
                                     mean_.mutable_cpu_data());

        // Broadcast the N*C means back to full N*C*H*W shape.
        multicast_cpu(num, ch, spatial, mean_.cpu_data(),
                      temp_NCHW_.mutable_cpu_data());

        // out = x - E[x]  (in-place when bottom aliases top)
        if (bottom[0] != top[0])
            caffe_copy(count, in, out);
        caffe_axpy<Dtype>(count, Dtype(-1.0f), temp_NCHW_.cpu_data(), out);

        // Variance: E[(x - E[x])^2], reduced per (n, c).
        caffe_powx<Dtype>(count, out, Dtype(2.0f), temp_NCHW_.mutable_cpu_data());
        compute_mean_per_channel_cpu(num, ch, spatial, temp_NCHW_.cpu_data(),
                                     var_.mutable_cpu_data());

        // inv_var = (var + eps)^(-1/2)
        caffe_add_scalar(nc, eps_, var_.mutable_cpu_data());
        caffe_powx(nc, var_.cpu_data(), Dtype(-0.5f), inv_var_.mutable_cpu_data());

        // x_norm = (x - E[x]) * inv_var, with inv_var broadcast to full shape.
        multicast_cpu(num, ch, spatial, inv_var_.cpu_data(),
                      temp_NCHW_.mutable_cpu_data());
        caffe_mul(count, out, temp_NCHW_.cpu_data(), out);

        // -- STAGE 2:  Y = X_norm * scale[c] + shift[c]  -----------------
        if (scale_bias_) {
            // y = x_norm * scale[c]
            const Blob<Dtype> &gamma = *(this->blobs_[0]);
            multicast_cpu_v2(num, ch, spatial, gamma.cpu_data(),
                             temp_NCHW_.mutable_cpu_data());
            caffe_mul(count, out, temp_NCHW_.cpu_data(), out);

            // y = y + shift[c]
            const Blob<Dtype> &beta = *(this->blobs_[1]);
            multicast_cpu_v2(num, ch, spatial, beta.cpu_data(),
                             temp_NCHW_.mutable_cpu_data());
            caffe_add(count, out, temp_NCHW_.mutable_cpu_data(), out);
        }
    }

    // Instantiate the template for float and double, and register the layer
    // under the type name "InstanceNorm" with the layer factory.
    INSTANTIATE_CLASS(InstanceNormLayer);
    REGISTER_LAYER_CLASS(InstanceNorm);
}  // namespace caffe

